
[kv_offload+HMA][4/N]: Support sliding window lookup #36645

Merged
orozery merged 3 commits into vllm-project:main from orozery:kv-offload-sliding-window-lookup
Apr 20, 2026

Conversation

@orozery (Collaborator) commented Mar 10, 2026

This PR adds lookup support for sliding window attention groups.

@gemini-code-assist Bot (Contributor) left a comment


Code Review

This pull request introduces support for sliding window lookup in the KV offloading mechanism. This is achieved by adding a new sliding_window_lookup method to the OffloadingManager interface and implementing it in LRUOffloadingManager and ARCOffloadingManager. The existing lookup method has been renamed to maximal_prefix_lookup for clarity. The changes are well-tested.

My main feedback is regarding code duplication in the implementations of sliding_window_lookup and maximal_prefix_lookup across the different manager classes. I've suggested a refactoring to improve maintainability.
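
For context, a minimal sketch of the interface shape this review describes; the method names come from the review, while the signatures, the BlockHash alias, and the docstrings are assumptions (the conversation below also shows the final PR moving this lookup logic into OffloadingConnectorScheduler):

```python
# Sketch only: signatures inferred from the review text, not copied from the PR.
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence

BlockHash = bytes  # assumed alias, for illustration


class OffloadingManager(ABC):
    @abstractmethod
    def maximal_prefix_lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """Count the consecutive stored blocks at the start of the sequence."""

    @abstractmethod
    def sliding_window_lookup(
        self, block_hashes: Sequence[BlockHash], sliding_window_size: int
    ) -> int:
        """Find the maximal ending position of sliding_window_size
        consecutive stored blocks, falling back to the prefix hit
        count when no full window exists."""
```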

Comment thread on vllm/v1/kv_offload/lru_manager.py (outdated)
@gambletan (Contributor) left a comment


Nice work adding the sliding_window_lookup method to the offloading managers. The API design and test coverage look solid. A few notes:

  1. Algorithmic concern in sliding_window_lookup (both lru_manager.py and arc_manager.py): The implementation scans right-to-left and returns the first window of sliding_window_size consecutive hits found from the end. The docstring says it "finds the maximal ending position," which implies we want the rightmost (largest-index) window; since the scan goes from right to left and returns on the first match, it does correctly find the rightmost qualifying window. However, the fallback behavior (returning consecutive_hits when no full window is found) returns the count of consecutive hits at the beginning of the array, since the loop finishes at index 0. This is consistent with maximal_prefix_lookup behavior, which is presumably intentional; a brief inline comment explaining these dual semantics would help (see the sketch after this list).

  2. Test assertion sliding_window_lookup(to_hashes([2, 1, 3, 4, 5]), 4) == 1: This is a great edge case. With blocks [2,3,4,5] stored and window size 4, we check [2,1,3,4,5] — block 1 is not stored, so no window of 4 exists. The fallback prefix is just block 2 (1 consecutive hit from the start). Clear and correct.

  3. Sequence vs Iterable: Good decision to require Sequence[BlockHash] for sliding_window_lookup since you need len() and indexing, while keeping Iterable[BlockHash] for maximal_prefix_lookup which only needs forward iteration.

  4. Missing sliding_window_lookup test for ARC manager: The ARC manager tests only rename lookup to maximal_prefix_lookup and don't add any sliding_window_lookup assertions. Since the ARC implementation shares the same logic, it would be good to add at least one or two sliding_window_lookup assertions in the ARC test functions to ensure the T1/T2 dual-cache lookup works correctly with the sliding window.
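
To make note 1 concrete, here is a hedged sketch of the right-to-left scan and its dual fallback semantics; `stored` stands in for the manager's internal block store, and everything beyond the behavior described above is an assumption:

```python
from collections.abc import Sequence

# Illustrative only: mirrors the scan described in the review, not the PR's code.
def sliding_window_lookup(stored: set, keys: Sequence, sliding_window_size: int) -> int:
    consecutive_hits = 0
    for i in range(len(keys) - 1, -1, -1):  # scan right-to-left
        if keys[i] in stored:
            consecutive_hits += 1
            if consecutive_hits == sliding_window_size:
                # Rightmost full window found: it covers keys[i : i + size],
                # so the maximal ending position is i + size.
                return i + sliding_window_size
        else:
            consecutive_hits = 0
    # No full window: the loop finished at index 0, so consecutive_hits now
    # counts the hits at the beginning of the array, matching the
    # maximal_prefix_lookup fallback. E.g. keys [2, 1, 3, 4, 5] with window
    # size 4 and block 1 missing returns 1.
    return consecutive_hits
```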

@orozery force-pushed the kv-offload-sliding-window-lookup branch from d798ce7 to d661471 on March 24, 2026 07:29

mergify Bot commented Mar 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify Bot added the needs-rebase label Mar 27, 2026
@orozery force-pushed the kv-offload-sliding-window-lookup branch from d661471 to ad9bba7 on March 29, 2026 04:50
@mergify Bot removed the needs-rebase label Mar 29, 2026
@orozery added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 7, 2026

mergify Bot commented Apr 7, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify Bot added the needs-rebase label Apr 7, 2026
@orozery force-pushed the kv-offload-sliding-window-lookup branch from ad9bba7 to 58604e9 on April 7, 2026 13:07
@mergify Bot removed the needs-rebase label Apr 7, 2026

mergify Bot commented Apr 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify Bot added the needs-rebase label Apr 8, 2026
@orozery force-pushed the kv-offload-sliding-window-lookup branch from 58604e9 to cf38f07 on April 9, 2026 09:32
@mergify Bot removed the needs-rebase label Apr 9, 2026
@orozery force-pushed the kv-offload-sliding-window-lookup branch from cf38f07 to a49aad8 on April 9, 2026 11:55
Comment on lines +148 to +150:

```python
    def _sliding_window_lookup(
        self, keys: Sequence[OffloadKey], sliding_window_size: int
    ) -> int | None:
```
Collaborator:

where are we using this function..?

@orozery (Author):

It will be used in a follow-up PR, when we actually parse SlidingWindowSpec / MambaSpec.
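
A hedged guess at how such a follow-up might dispatch per group spec; SlidingWindowSpec is named in the reply above and the lookup helpers mirror this PR, but the field access and the ceil division of the window into blocks are assumptions:

```python
# Hypothetical dispatch for a follow-up PR; not actual vLLM code.
from vllm.v1.kv_cache_interface import SlidingWindowSpec

def lookup_for_group(scheduler, spec, offload_keys, offloaded_block_size):
    if isinstance(spec, SlidingWindowSpec):
        # Only the last sliding_window tokens are attended, so any run of
        # consecutive offloaded blocks covering them is reusable.
        window_blocks = -(-spec.sliding_window // offloaded_block_size)  # ceil
        return scheduler._sliding_window_lookup(offload_keys, window_blocks)
    # Full attention: only a contiguous prefix of blocks is reusable.
    return scheduler._maximal_prefix_lookup(offload_keys)
```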

```diff
@@ -181,7 +219,7 @@ def get_num_new_matched_tokens(
             return 0, False

         start_block_idx = num_computed_tokens // group_config.offloaded_block_size
-        hits = self.manager.lookup(offload_keys[start_block_idx:])
+        hits = self._maximal_prefix_lookup(offload_keys[start_block_idx:])
```
Collaborator:

I know this may look trivial, but I think we should comment here why we're doing a maximal prefix lookup
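
For instance, a comment along these lines above the call (suggested wording only, not the PR's actual code):

```python
# Full-attention groups can only reuse a contiguous prefix of offloaded
# blocks (a gap invalidates everything after it), so a maximal prefix
# lookup is the right primitive here, unlike sliding-window groups.
hits = self._maximal_prefix_lookup(offload_keys[start_block_idx:])
```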

Comment on lines +138 to +139:

```python
            if result is None:
                defer_lookup = True
```
Collaborator:

do we still need to loop through all keys here if we just return None in this case?
also qq, None here means the key is not present at all?

@orozery (Author):

None means "retry lookup later".
This is also the reason we loop through all keys: to allow the backend to kick off async lookups for all keys.
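
A minimal sketch of that pattern, with all names assumed (the real backend API may differ):

```python
# Illustrative only: keep scanning after an unknown result so the backend
# can start async lookups for every key, then tell the caller to retry.
def lookup_with_deferral(keys, backend):
    hits = 0
    counting = True        # still extending the prefix of definite hits
    defer_lookup = False   # saw a key the backend does not know about yet
    for key in keys:
        result = backend.lookup(key)  # may kick off an async lookup
        if result is None:
            defer_lookup = True       # unknown for now; retry the lookup later
        elif not result:
            counting = False          # a definite miss ends the usable prefix
        elif counting:
            hits += 1
    return None if defer_lookup else hits
```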

@orozery force-pushed the kv-offload-sliding-window-lookup branch from a49aad8 to 322d2a2 on April 10, 2026 15:17
This commit adds lookup support for sliding window attention groups.
We move the lookup logic from OffloadingManager to OffloadingConnectorScheduler.
This allows a simpler API for OffloadingManager.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery force-pushed the kv-offload-sliding-window-lookup branch from 322d2a2 to 3690b95 on April 20, 2026 06:42
@orozery requested a review from xuechendi as a code owner on April 20, 2026 06:42
@orozery merged commit f774ba0 into vllm-project:main Apr 20, 2026
61 checks passed
bnellnm pushed a commit to neuralmagic/vllm that referenced this pull request Apr 20, 2026
[kv_offload+HMA][4/N]: Support sliding window lookup (vllm-project#36645)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
[kv_offload+HMA][4/N]: Support sliding window lookup (vllm-project#36645)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
[kv_offload+HMA][4/N]: Support sliding window lookup (vllm-project#36645)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
iboiko-habana pushed a commit to vllm-project/vllm-gaudi that referenced this pull request Apr 29, 2026
…stream breakages: NIXL connector, TpKVTopology rename, MoE refactor, transformers v5 (#1377)

## Summary

Compatibility fixes for vLLM bump to `3975eb6de6`. Addresses breakages
from multiple upstream PRs affecting NIXL connectors, MoE runner
refactor, offloading tests, Qwen3 MoE models, and transformers v5
upgrade.

## Root Cause

1. **NIXL import gate** — Upstream PR
vllm-project/vllm#39529 (commit `cc3993b05d`)
moved NIXL imports to `vllm/distributed/nixl_utils.py` and changed the
platform gate from `if not is_rocm()` to `if is_cuda()`. HPU is neither
CUDA nor ROCm, so it falls into the `else` branch → tries `rixl._api`
(ROCm-only) → fails → `NixlWrapper = None` → `RuntimeError("NIXL is not
available")`.

2. **TpKVTopology rename** — Same upstream PR #39529 unified
`TpKVTopology` + `HeteroTPTransferConfig` into `TransferTopology`,
breaking vllm-gaudi NIXL connector imports.

3. **Offloading tests** — Upstream PR
vllm-project/vllm#36645 changed
`OffloadingManager.lookup()` API.

4. **MoE runner refactor** — Upstream PR
vllm-project/vllm#35949 (commit `726efe177b`)
moved reduce logic into `MoERunnerBase`, removing `reduce_results`,
renaming `forward_dispatch` → `_forward_dispatch`, `forward_entry` →
`_forward_entry`, `_maybe_reduce_output` → `_maybe_reduce_final_output`.
Follow-up PR moved `MoERunnerBase` and `get_layer_from_name` to
`moe_runner_base.py`.

5. **Qwen3 MoE** — `SharedFusedMoE` returns a combined tensor (not a
tuple), and MoE runner now handles TP reduction internally, causing
double-reduce in `qwen3_moe.py` / `qwen3_next.py`.

6. **Transformers v5 — granite tokenizer** — Upstream PR
vllm-project/vllm#30566 updated transformers to
allow v5. GPT2Tokenizer in v5 now respects `add_bos_token=True`
(silently ignored in v4), causing degenerate outputs and 0.0 GSM8K
accuracy on granite models.

7. **Transformers v5.6.x — DeepSeek-V2-Lite tokenizer** — In
transformers v5.6.x, `LlamaTokenizerFast` was unified into
`LlamaTokenizer`, which does not apply the ByteLevel BPE decoder
declared in `tokenizer.json`. DeepSeek-V2-Lite-Chat's tokenizer decoding
strips all spaces (Ġ chars not converted back), producing garbled output
and 0.0 accuracy on GSM8K. Fixed natively in transformers v5.7.0.

## Fix

1. **NIXL import patch**: Add `patch_nixl_utils_for_hpu()` in
`register_utils()` to monkey-patch `vllm.distributed.nixl_utils` —
imports from `nixl._api` instead of `rixl._api` on HPU. Update
`hetero_hpu_nixl_connector.py` to import from
`vllm.distributed.nixl_utils` instead of hardcoded `nixl._api`.
2. **TpKVTopology → TransferTopology**: Rename in NIXL connector imports
and monkey-patches.
3. **Offloading tests**: Replace `runner.manager.lookup.return_value`
with `connector_scheduler._maximal_prefix_lookup`.
4. **MoE refactor**: Update imports (`MoERunnerBase` from
`moe_runner_base`), method names (`_forward_dispatch`, `_forward_entry`,
`_maybe_reduce_final_output`), remove dead `reduce_results` /
`reduce_output()`.
5. **Qwen3 MoE**: Remove incorrect shared_expert tuple indexing and
double TP reduction.
6. **Transformers v5 — granite**: Remove hardcoded `add_bos_token=True`
from lm-eval model_args to fix GSM8K accuracy regression.
7. **Transformers v5.6.x — DeepSeek-V2-Lite**: Exclude `transformers
5.6.*` in `requirements.txt` to prevent installation of versions with
broken ByteLevel BPE tokenizer decoding. Verified on Gaudi2: gsm8k
accuracy 0.65 (expected 0.66, within tolerance) with transformers 5.7.0.

---------

Signed-off-by: Paweł Olejniczak <pawelx.olejniczak@intel.com>
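
For fix (1), a hedged sketch of what such a monkey-patch could look like; patch_nixl_utils_for_hpu and the nixl_utils module come from the commit text, while nixl_agent and the patched attribute are assumptions about the surrounding code:

```python
# Hypothetical sketch, not vllm-gaudi's actual implementation.
import importlib

def patch_nixl_utils_for_hpu() -> None:
    """Point vllm.distributed.nixl_utils at nixl._api on HPU.

    Upstream gates the import on is_cuda(), so HPU falls into the ROCm
    branch, tries rixl._api, fails, and leaves NixlWrapper = None.
    """
    nixl_utils = importlib.import_module("vllm.distributed.nixl_utils")
    nixl_api = importlib.import_module("nixl._api")  # plain NIXL works on HPU
    nixl_utils.NixlWrapper = nixl_api.nixl_agent     # replace the None wrapper
```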
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
[kv_offload+HMA][4/N]: Support sliding window lookup (vllm-project#36645)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Signed-off-by: Adrian <info@zzit.ch>

Labels

kv-connector, ready (ONLY add when PR is ready to merge/full CI is needed), v1

3 participants